from __future__ import print_function

import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import lime
import lime.lime_tabular
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import Normalizer, OneHotEncoder, StandardScaler, MinMaxScaler
from sklearn.utils import shuffle
The main libraries used include sklearn, matplotlib, pandas, numpy and lime.
# Load the UCI Adult train and test splits as raw string arrays.
datasets_1 = np.genfromtxt('./Data/adult.data', delimiter=', ', dtype=str)
datasets_2 = np.genfromtxt('./Data/adult.test', delimiter=', ', dtype=str, autostrip=True)
# The test-split labels carry a trailing '.' (e.g. '>50K.'); strip it so
# they match the training-split labels.
for row, label in enumerate(datasets_2[:, 14]):
    datasets_2[row, 14] = label.strip('.')
# Column names: 14 features plus the income label.
feature_and_labels_names = [
    "Age", "Workclass", "fnlwgt", "Education", "Education-Num",
    "Marital Status", "Occupation", "Relationship", "Race", "Sex",
    "Capital Gain", "Capital Loss", "Hours per week", "Country", "income",
]
# Stack the two splits into a single table.
alldata = np.concatenate((datasets_1, datasets_2), axis=0)
df = pd.DataFrame(alldata, columns=feature_and_labels_names)
from numpy import unique

# Report the number of distinct values per column; a column with a single
# value would be constant and therefore uninformative.
for col in range(alldata.shape[1]):
    distinct = len(unique(alldata[:, col]))
    if distinct <= 1:
        print(col)
    else:
        print(col, distinct)
0 74 1 9 2 28523 3 16 4 16 5 7 6 15 7 6 8 5 9 2 10 123 11 99 12 96 13 42 14 2
This step shows the number of distinct values of each feature.
print(df.shape)
# Remove every row containing the missing-value marker '?' in any column.
# regex=False matches the literal character (equivalent to the old '\?'
# pattern, without the invalid-escape warning).
for a in feature_and_labels_names:
    df = df[df[a].str.contains('?', regex=False) == False]
print(df.shape)
# BUG FIX: drop_duplicates() and dropna() return new DataFrames; the
# original code discarded the results, so duplicates and NaN rows were
# never actually removed. Assign the results back to df.
df = df.drop_duplicates()
df = df.dropna()
df.isnull().sum()
print(df.shape)
(48842, 15) (45222, 15) (45222, 15)
Missing values and duplicated rows are removed.
# fnlwgt arrives as strings; convert it to int so it can be treated numerically.
df['fnlwgt'] = df['fnlwgt'].astype(str).astype(int)
print(df['fnlwgt'])
# Normalized histogram of the fnlwgt distribution.
ax = df.hist(column=['fnlwgt'], density=True, bins=500)
0 77516
1 83311
2 215646
3 234721
4 338409
...
48836 245211
48837 215419
48839 374983
48840 83891
48841 182148
Name: fnlwgt, Length: 45222, dtype: int32
Visualizing the distribution of the fnlwgt feature.
# Drop extreme outliers: keep only rows below the 99.5th percentile of fnlwgt.
q_hi = df["fnlwgt"].quantile(0.995)
df = df[df["fnlwgt"] < q_hi]
# Re-plot the distribution after outlier removal.
ax = df.hist(column=['fnlwgt'], density=True, bins=500)
print(df.shape)
(44995, 15)
Removing outliers
# Education-Num duplicates the information in Education, so drop it.
df = df.drop('Education-Num', axis=1)
# Fixed seed so the shuffle below is reproducible.
np.random.seed(5)
data = df.to_numpy()
np.random.shuffle(data)
y = data[:, -1]                        # income label column
data = np.delete(data, [13], axis=1)   # strip the label from the features
Feature selection is done in this step: Education-Num is deleted because it duplicates the Education feature.
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import LabelEncoder

# Column indices of the categorical features (after dropping Education-Num).
categorical_features = [0, 1, 3, 4, 5, 6, 7, 8, 12]
categorical_names = {}
# Integer-encode every categorical column, remembering each column's class
# labels so LIME can display readable feature values later.
for feature in categorical_features:
    le = LabelEncoder()
    le.fit(data[:, feature])
    data[:, feature] = le.transform(data[:, feature])
    categorical_names[feature] = le.classes_
# One-hot encode the categorical columns; min-max scale the numeric ones.
t = [
    ('cat', OneHotEncoder(), [0, 1, 3, 4, 5, 6, 7, 8, 12]),
    ('num', MinMaxScaler(), [2, 9, 10, 11]),
]
transformer = ColumnTransformer(transformers=t)
# Encode the income label as 0/1 and keep the class names for display.
le = LabelEncoder()
y = le.fit_transform(y)
class_names = le.classes_
data = data.astype(float)
# The first 3000 instances form the initial modelling subset.
data_1 = data[0:3000]
y_1 = y[0:3000]
Data transformation is performed in this step; both categorical and numerical features are transformed.
# Random 80/20 train/test split of the 3000-instance subset.
train, test, labels_train, labels_test = train_test_split(
    data_1, y_1, train_size=0.80, random_state=1)
# Fit the column transformer on the full data, then encode the training set.
transformer.fit(data)
Encoded_train = transformer.transform(train)
The data is split randomly into training and test sets with an 80:20 ratio.
from sklearn.metrics import f1_score
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import GridSearchCV

# Base estimator and the hyper-parameter grid to explore.
mlp_qs = MLPClassifier(max_iter=2000)
parameter_space = {
    'hidden_layer_sizes': [(5,), (10,), (15,), (20,), (40,)],
    'activation': ['tanh', 'relu'],
    'solver': ['sgd'],
    'alpha': [0.001, 0.005, 0.01],
    # 'learning_rate': ['constant', 'adaptive'],
}
# 2-fold cross-validated grid search over every combination, using all cores.
clf = GridSearchCV(mlp_qs, parameter_space, n_jobs=-1, cv=2)
clf.fit(Encoded_train, labels_train)
hidden_layer_sizes=(10,), random_state=3,max_iter=500,
activation='relu',alpha =0.001,solver='sgd'
File "C:\Users\wjiames\AppData\Local\Temp/ipykernel_20572/3174174614.py", line 2 hyperparameters tune ^ SyntaxError: invalid syntax
Grid search is done in this step to find the best hyper-parameters for the neural network. The best parameters are "hidden_layer_sizes=(15,), random_state=3,max_iter=500, activation='relu',alpha =0.001,solver='sgd'".
# Final model using the hyper-parameters selected by the grid search.
clf = MLPClassifier(hidden_layer_sizes=(15,), random_state=3, max_iter=1000,
                    activation='relu', alpha=0.001, solver='sgd')
clf.fit(Encoded_train, labels_train)
# NOTE(review): despite the name, y_true holds the model's PREDICTIONS on
# the test split; f1_score/accuracy_score are symmetric in their first two
# arguments here, so the scores are unaffected.
y_true = clf.predict(transformer.transform(test))
# Probability wrapper for LIME: raw feature rows in, class probabilities out.
predict_fn = lambda x: clf.predict_proba(transformer.transform(x)).astype(float)
print(f1_score(y_true, labels_test, average='macro'))
accuracy_score(y_true, labels_test)
0.7314746710564721
0.8083333333333333
This is the model performance for the classifier f.
import lime
import lime.lime_tabular

# Feature names after Education-Num was removed (13 features remain).
feature_names = ["Age", "Workclass", "fnlwgt", "Education", "Marital Status",
                 "Occupation", "Relationship", "Race", "Sex", "Capital Gain",
                 "Capital Loss", "Hours per week", "Country"]
explainer = lime.lime_tabular.LimeTabularExplainer(
    data,
    feature_names=feature_names,
    class_names=class_names,
    categorical_features=categorical_features,
    categorical_names=categorical_names)
# Show local explanations for the first five instances.
a = np.arange(5)  # Generating 5 examples.
for i in a:
    b = explainer.explain_instance(data_1[i], predict_fn,
                                   num_features=5, top_labels=1)
    b.show_in_notebook(predict_proba=True, show_predicted_value=True)
import numpy as np

# Collect, per feature index (0..12), the absolute LIME weight every time
# that feature appears in a top-5 local explanation of any instance.
fir = {idx: [] for idx in range(13)}
for row in np.arange(len(data_1)):
    exp = explainer.explain_instance(data_1[row], predict_fn,
                                     num_features=5, top_labels=1)
    for lbl in exp.local_exp:
        for feat, weight in np.asarray(exp.local_exp[lbl]):
            fir[feat].append(abs(weight))
# np.save('data/classfir.npy',fir)
# Reload previously saved results (this overwrites the freshly computed dict).
fir = np.load('data/classfir.npy', allow_pickle=True).item()
d = []  # per-feature importance samples (boxplot input)
l = []  # how often each feature entered a top-5 explanation
for values in fir.values():
    l.append(len(values))
    d.append(values)
Generating feature-based explanations for all instances and aggregating them to get global feature importance.
# Figure: per-feature LIME importance distributions (boxplots, left axis)
# overlaid with how often each feature entered a top-5 explanation (line,
# right axis).
fig,ax = plt.subplots()
plt.rcParams["figure.figsize"] = (16,9)
# Display labels; "Weight" is the fnlwgt column renamed for readability.
labels = ["Age", "Workclass", "Weight", "Education", "Marital Status","Occupation",
"Relationship", "Race", "Sex", "Capital Gain", "Capital Loss","Hours per week","Country"]
x = np.arange(len(labels))
# d holds one list of importance samples per feature; outliers hidden.
ax.boxplot(d,labels = labels,showfliers = False,patch_artist=True,positions =x)
ax.set_ylabel('Feature importance',fontsize = 30,color='C0')
ax.grid(axis='x')
plt.xticks(fontsize =20,rotation = 50)
plt.yticks(fontsize=20)
# Secondary y-axis sharing the same x positions: selection frequency l.
ax1=ax.twinx()
ax1.plot(x,l,color = 'C1',marker='*',ms=20)
ax1.set_ylabel('Feature frequency',fontsize = 30,color='C1')
# NOTE: after twinx(), plt.yticks applies to the new current axes (ax1).
plt.yticks(fontsize=20)
plt.show()
fig.figure.savefig('./Downloads/featureaverageimportance.png',dpi=450,bbox_inches = 'tight')
Drawing the figure for feature importance, including the error, average importance, and frequency of the top-five feature importances at each local explanation. It is apparent that the features weight and hours per week are unimportant.
# Remove the two least-important features identified above (fnlwgt and
# Hours per week: columns 2 and 11) from both the subset and the full data.
data_2 = np.delete(data_1, [2, 11], axis=1)
data = np.delete(data, [2, 11], axis=1)
# Fresh 80/20 split on the reduced feature set (same seed as before).
r_train, r_test, r_labels_train, r_labels_test = train_test_split(
    data_2, y_1, train_size=0.80, random_state=1)
# Re-declare the transformer for the 11 remaining columns.
x = [('cat', OneHotEncoder(), [0, 1, 2, 3, 4, 5, 6, 7, 10]),
     ('num', MinMaxScaler(), [8, 9])]
r_transformer = ColumnTransformer(transformers=x)
r_transformer.fit(data)
Encoded_r_train = r_transformer.transform(r_train)
Based on the figure above, two features (Weight and Hours per week) are removed and the data is then updated.
from sklearn.metrics import accuracy_score

# Retrain the classifier on the reduced feature set (classifier fa).
clf1 = MLPClassifier(hidden_layer_sizes=(15,), random_state=3, max_iter=1000,
                     activation='relu', alpha=0.001, solver='sgd')
clf1.fit(Encoded_r_train, r_labels_train)
r_y_true = clf1.predict(r_transformer.transform(r_test))
# BUG FIX: score against r_labels_test (the labels belonging to r_test),
# not labels_test from the earlier full-feature split. The two happen to
# coincide here only because both splits use the same y_1 and random_state,
# so the reported numbers are unchanged — but the original was fragile.
print(f1_score(r_y_true, r_labels_test, average='macro'))
print(accuracy_score(r_y_true, r_labels_test))
0.7304157273257555 0.81
The updated data is used to retrain the model(classifier fa). The performance is measured by Macro-F1 and Accuracy.
# Sample 30000 further instances: the first 15000 keep their true labels,
# the second 15000 get random 0/1 labels (deliberately corrupted).
x1 = data[3000:18000]
x2 = data[18000:33000]
y1 = y[3000:18000]
y2 = np.random.randint(0, 2, size=15000)
print(y2.shape, y1.shape)
x = np.concatenate((x1, x2), axis=0)
y = np.concatenate((y1, y2), axis=0)
# Partition into 30 groups of 1000 instances each.
x = x.reshape(30, 1000, 11)
y = y.reshape(30, 1000,)
# Per-group accumulators for [accuracy, macro-F1] marginal contributions.
datakeys = [str(i) for i in list(range(30))]
datavalue = {i: [0, 0] for i in range(len(datakeys))}
(15000,) (15000,)
30000 data instances are sampled from the remaining data, and the labels of 15000 instances are changed randomly, i.e. corrupted.
def accuracy(x, y, mode):
    """Train a fresh MLP on (x, y) and score it on the reduced test split.

    x, y  -- raw (un-encoded) training features and labels
    mode  -- 1: a single partial_fit pass (gradient-based update);
             0: a full fit (retrain from scratch)
    Returns [micro-F1, macro-F1]; micro-F1 equals plain accuracy.
    """
    classifier = MLPClassifier(hidden_layer_sizes=(15,), random_state=3,
                               max_iter=1000, activation='relu',
                               alpha=0.001, solver='sgd')
    encodedtrain = r_transformer.transform(x)
    if mode == 1:
        classifier.partial_fit(encodedtrain, y, classes=list(range(2)))  # gradient-based method
    elif mode == 0:
        classifier.fit(encodedtrain, y)  # retrain the model
    # Renamed from 'accuracy'/'r_y_true': the old local shadowed the
    # function's own name and mislabelled predictions as ground truth.
    preds = classifier.predict(r_transformer.transform(r_test))
    # BUG FIX: score against r_labels_test (the labels of r_test) rather
    # than labels_test from the earlier full-feature split; the two are
    # identical here only by coincidence of the shared split seed.
    micro = f1_score(preds, r_labels_test, average='micro')  # micro-F1 == accuracy
    macro = f1_score(preds, r_labels_test, average='macro')
    return [micro, macro]
# Baseline score of the model trained only on r_train.
inicial_accuracy = accuracy(r_train, r_labels_train, 0)

def data_value(x, y, test=r_train, label=r_labels_train):
    """One Monte-Carlo permutation pass of group-level data valuation.

    Feeds the groups in x (with labels y) to the model in a random order,
    measuring each group's marginal [accuracy, macro-F1] contribution and
    accumulating it into the global `datavalue` dict.
    """
    each_accuracy = [inicial_accuracy]
    data_i = test
    # BUG FIX: use the `label` parameter instead of always reading the
    # global r_labels_train. The default value is the same object, so
    # existing calls behave identically, but the parameter is no longer
    # silently ignored.
    y_1 = label
    for i in np.random.permutation(len(x)):  # Monte-Carlo method
        data_i = np.concatenate((data_i, x[i]), axis=0)
        y_1 = np.concatenate((y_1, y[i]), axis=0)
        each_accuracy.append(accuracy(data_i, y_1, 1))
        # Marginal contribution of group i: score with the group minus
        # the score just before it was added.
        metric = np.array(each_accuracy[-1]) - np.array(each_accuracy[-2])
        datavalue[i][0] += metric[0]
        datavalue[i][1] += metric[1]
# Calculating the marginal contribution of each data group
### Reaching convergence
n = 300  # number of Monte-Carlo permutation passes
# each_turn_data=[]
for _ in range(n):
    data_value(x, y)
### Calculating the average value
# Average the accumulated contributions over the n passes.
group_value = [np.array(datavalue[i]).astype(float) / n for i in range(len(x))]
{0: [-0.42999999999999994, -0.6599349494064913], 1: [0.019999999999999796, -0.27629430641449276], 2: [-0.05499999999999983, -0.31582888427970185], 3: [0.014999999999999958, -0.3029583771878191], 4: [-0.06666666666666721, -0.2932980038423557], 5: [0.3833333333333333, 0.07131656561095567], 6: [0.2533333333333333, 0.01852694793655346], 7: [-0.1333333333333332, -0.3624080241462838], 8: [0.15666666666666645, -0.014852949327723963], 9: [0.32833333333333337, -0.005189065635165135], 10: [0.1799999999999995, -0.0227903487171926], 11: [0.23333333333333361, 0.000841991864602154], 12: [0.15500000000000014, 0.025095590873332962], 13: [0.2816666666666664, 0.041822274679304605], 14: [-0.05166666666666697, -0.33766965847460695], 15: [0.0033333333333331883, 0.020916620687084053], 16: [0.0016666666666668162, 0.006449373388196955], 17: [0.023333333333333095, 0.007216040166735327], 18: [0.010000000000000064, 0.029269831760748544], 19: [-0.0249999999999998, -0.010281998138405357], 20: [0.0033333333333335213, 0.005825000308292305], 21: [-0.3166666666666668, -0.31306646091251006], 22: [-0.3033333333333334, -0.322615307884046], 23: [0.016666666666666663, 0.01023489986737014], 24: [-0.34500000000000025, -0.32409369973895136], 25: [-0.325, -0.33230458673902963], 26: [0.03166666666666673, 0.010025132491535804], 27: [-0.0050000000000004485, 0.007897497709931534], 28: [-0.3566666666666665, -0.3323137245245185], 29: [-1.0200000000000002, -0.983717457150619]}
{0: [-0.41999999999999993, -0.6607365422673528], 1: [0.019999999999999796, -0.28129437303939997], 2: [-0.009999999999999842, -0.3026242288091748], 3: [0.0683333333333333, -0.2915032257203313], 4: [-0.06666666666666721, -0.29916900433012117], 5: [0.3883333333333333, 0.07331772345799448], 6: [0.265, 0.01780240325608201], 7: [-0.10666666666666658, -0.3624080241462838], 8: [0.16666666666666635, -0.015239811136487924], 9: [0.3350000000000001, -0.027169458182128958], 10: [0.1833333333333328, -0.027215959196249806], 11: [0.2350000000000001, 0.0014388616601476167], 12: [0.17166666666666686, 0.03345911542985153], 13: [0.32333333333333303, 0.053690016125694184], 14: [-0.041666666666667074, -0.3395228954750097], 15: [0.016666666666666496, 0.025475893660616478], 16: [-0.3483333333333332, -0.325570632012425], 17: [0.026666666666666394, 0.008897168551616097], 18: [0.010000000000000064, 0.034925989008023184], 19: [-0.023333333333333095, -0.004614691007120086], 20: [0.0016666666666668162, 0.005157759801914241], 21: [-0.3150000000000001, -0.31234654119310556], 22: [-0.3033333333333334, -0.322615307884046], 23: [0.013333333333333475, 0.00896591225607879], 24: [-0.34500000000000025, -0.32409369973895136], 25: [-0.325, -0.33230458673902963], 26: [0.04166666666666674, 0.01571443863563382], 27: [-0.006666666666667043, 0.0072644781201213204], 28: [-0.3583333333333332, -0.33953332099886585], 29: [-1.0166666666666668, -0.9818937997392945]}
{0: [-0.3833333333333333, -0.6499481145083521], 1: [0.019999999999999796, -0.2871653735271654], 2: [-0.008333333333333137, -0.3076828588780422], 3: [0.07500000000000001, -0.31348361826729515], 4: [-0.0650000000000005, -0.2985017638237431], 5: [0.3899999999999998, 0.07391459325353994], 6: [0.28500000000000003, 0.027847056197481346], 7: [-0.09499999999999997, -0.3620749662356433], 8: [0.16833333333333295, -0.019516865137054806], 9: [0.39, -0.017239835628477085], 10: [0.1933333333333328, -0.028510630199259646], 11: [0.255, -0.0021078279866219574], 12: [0.18500000000000016, 0.03902023959308354], 13: [0.33499999999999974, 0.05247183517719922], 14: [0.004999999999999616, -0.32027036266938613], 15: [0.019999999999999796, 0.027309829656207485], 16: [-0.3483333333333332, -0.325570632012425], 17: [0.04166666666666641, 0.009797765171063288], 18: [0.013333333333333475, 0.03675960952392221], 19: [-0.37333333333333313, -0.336634696407742], 20: [3.3306690738754696e-16, 0.004598194024957147], 21: [-0.3100000000000001, -0.3096858246201221], 22: [-0.3033333333333334, -0.322615307884046], 23: [0.010000000000000175, 0.00757534884419514], 24: [-0.34500000000000025, -0.32409369973895136], 25: [-0.325, -0.326648429491755], 26: [0.04000000000000015, 0.01501933902399094], 27: [-0.008333333333333637, 0.006378056321212078], 28: [-0.3499999999999999, -0.3409194739363718], 29: [-1.0166666666666668, -0.9818937997392945]}
{0: [-0.3833333333333333, -0.6558191149961176], 1: [0.06333333333333313, -0.2781063699679756], 2: [-0.0016666666666664276, -0.32978514998416675], 3: [0.1233333333333334, -0.3006850247915598], 4: [-0.06166666666666709, -0.2971674770028669], 5: [0.3899999999999998, 0.06891452662863273], 6: [0.29500000000000015, 0.02652039493188113], 7: [-0.07333333333333325, -0.3511884950156787], 8: [0.17999999999999966, -0.014651150866630935], 9: [0.3916666666666665, -0.016642965832931622], 10: [0.23833333333333284, -0.01910015121191272], 11: [0.2649999999999999, -0.003961064987024698], 12: [0.19166666666666687, 0.03586157850576416], 13: [0.34499999999999964, 0.05208497336843526], 14: [0.03166666666666623, -0.32032265399615706], 15: [0.01833333333333309, 0.026476147909192282], 16: [-0.3483333333333332, -0.325570632012425], 17: [0.04166666666666641, 0.009797765171063288], 18: [0.013333333333333475, 0.04241576677119685], 19: [-0.37333333333333313, -0.33863074652961245], 20: [3.3306690738754696e-16, 0.004598194024957147], 21: [-0.6566666666666668, -0.6408777099636895], 22: [-0.2850000000000001, -0.31561966891546234], 23: [0.01833333333333348, 0.01274572543095348], 24: [-0.34500000000000025, -0.32409369973895136], 25: [-0.33, -0.328548398473303], 26: [0.03833333333333344, 0.014185657276975738], 27: [-0.010000000000000342, 0.005710815814834014], 28: [-0.3483333333333332, -0.3400857921893566], 29: [-1.0166666666666668, -0.9818937997392945]}
{0: [-0.3766666666666667, -0.6580109712512493], 1: [0.07333333333333314, -0.2785428779044835], 2: [0.051666666666666916, -0.31231385871041234], 3: [0.13333333333333341, -0.3010226052228586], 4: [-0.0583333333333339, -0.2959730800289142], 5: [0.4099999999999998, 0.07896731610158209], 6: [0.29833333333333356, 0.02785468175275735], 7: [-0.06666666666666654, -0.37316888756264255], 8: [0.19333333333333297, -0.009319431450412974], 9: [0.3983333333333332, -0.02515458351515254], 10: [0.25499999999999945, -0.02114536300728137], 11: [0.2649999999999999, -0.003961064987024698], 12: [0.1900000000000004, 0.02943101224104161], 13: [0.37999999999999967, 0.05834664538117784], 14: [0.08166666666666617, -0.31367728702969905], 15: [0.02166666666666639, 0.02831373464916237], 16: [-0.3499999999999999, -0.3264124502909902], 17: [-0.306666666666667, -0.3239972598311527], 18: [0.013333333333333475, 0.04241576677119685], 19: [-0.3749999999999997, -0.3346933780518414], 20: [0.010000000000000342, 0.008528535348443778], 21: [-0.6583333333333335, -0.6415977436987987], 22: [-0.2700000000000001, -0.31044451167323356], 23: [0.016666666666666774, 0.015881250681691106], 24: [-0.34500000000000025, -0.32409369973895136], 25: [-0.3283333333333333, -0.33168392372404065], 26: [0.03333333333333344, 0.01155749099632053], 27: [-0.011666666666667047, 0.004990890744889731], 28: [-0.3349999999999998, -0.3327449398843519], 29: [-1.0166666666666668, -0.9818937997392945]}
{0: [-0.3683333333333334, -0.6590684767524924], 1: [0.08666666666666645, -0.2729817537412515], 2: [0.061666666666666925, -0.3326131228724954], 3: [0.13500000000000012, -0.3053275183286903], 4: [-0.010000000000000564, -0.28023756691906276], 5: [0.4316666666666664, 0.07322137408952778], 6: [0.3533333333333335, 0.0443338552643851], 7: [-0.05666666666666664, -0.3750221245630453], 8: [0.19666666666666616, -0.008125034476460291], 9: [0.3999999999999997, -0.024557713719607077], 10: [0.27499999999999947, -0.011115918277952663], 11: [0.2649999999999999, -0.009832065474790153], 12: [0.22333333333333377, 0.039184027504649555], 13: [0.3899999999999997, 0.056922770902747555], 14: [0.08999999999999958, -0.3103428357306689], 15: [0.019999999999999685, 0.027646494142784306], 16: [-0.3499999999999999, -0.3264124502909902], 17: [-0.30833333333333346, -0.32459412962669815], 18: [0.013333333333333475, 0.04241576677119685], 19: [-0.3799999999999996, -0.33659632931716527], 20: [0.010000000000000342, 0.008528535348443778], 21: [-0.6583333333333335, -0.6415977436987987], 22: [-0.26333333333333336, -0.309938894416376], 23: [0.016666666666666774, 0.015881250681691106], 24: [-0.34500000000000025, -0.32409369973895136], 25: [-0.3083333333333333, -0.3309788056682412], 26: [0.03333333333333344, 0.01155749099632053], 27: [-0.011666666666667047, 0.004990890744889731], 28: [-0.3349999999999998, -0.3327449398843519], 29: [-1.3633333333333335, -1.3116320320594044]}
{0: [-0.3550000000000001, -0.6535075284550205], 1: [0.08999999999999964, -0.2717873567672988], 2: [0.07833333333333353, -0.33219644910264257], 3: [0.14333333333333353, -0.30199306702966017], 4: [-0.0016666666666672603, -0.2813196627548748], 5: [0.48999999999999977, 0.09315868240228237], 6: [0.3633333333333335, 0.03175341644305768], 7: [-0.02499999999999991, -0.37544385567022515], 8: [0.20499999999999946, -0.010182872608923088], 9: [0.4099999999999996, -0.026410950720009818], 10: [0.27499999999999947, -0.016986918765718118], 11: [0.31833333333333325, 0.003884287927661878], 12: [0.22333333333333377, 0.03418396087974235], 13: [0.07999999999999963, -0.2650067058314942], 14: [0.08999999999999958, -0.31621383621843435], 15: [0.02166666666666628, 0.02834172816983238], 16: [-0.34833333333333344, -0.31998188402626765], 17: [-0.30666666666666687, -0.3267603713269268], 18: [0.011666666666666992, 0.041856200994239756], 19: [-0.3833333333333328, -0.3378653169284566], 20: [0.008333333333333748, 0.007808599531477833], 21: [-0.6583333333333335, -0.6415977436987987], 22: [-0.26166666666666666, -0.30924356503154043], 23: [0.013333333333333475, 0.014490687269807456], 24: [-0.34500000000000025, -0.32409369973895136], 25: [-0.3083333333333333, -0.3309788056682412], 26: [0.03166666666666673, 0.010890250489942466], 27: [-0.006666666666667043, 0.004381669968860857], 28: [-0.3299999999999998, -0.33024386726120863], 29: [-1.3633333333333335, -1.3116320320594044]}
{0: [-0.3533333333333334, -0.6585661585238879], 1: [0.09499999999999964, -0.2749896317748477], 2: [0.10833333333333356, -0.33577695614019315], 3: [0.16000000000000014, -0.2950413446618017], 4: [-0.0016666666666672603, -0.28719066324264025], 5: [0.17999999999999972, -0.2287707943319594], 6: [0.3750000000000002, 0.03619902776482836], 7: [-0.016666666666666496, -0.37793176858922156], 8: [0.20499999999999946, -0.015182939233830295], 9: [0.4199999999999996, -0.02677312683641736], 10: [0.27666666666666595, -0.016390048970172655], 11: [0.3649999999999999, 0.01926866795101062], 12: [0.2750000000000004, 0.04729504855465472], 13: [0.10166666666666624, -0.2577476561153795], 14: [0.1016666666666663, -0.33567372625684644], 15: [0.03166666666666629, 0.03455356590561898], 16: [-0.34500000000000003, -0.31979783007204493], 17: [-0.30833333333333357, -0.3273943349809593], 18: [0.011666666666666992, 0.041856200994239756], 19: [-0.3799999999999995, -0.3376214087898117], 20: [0.006666666666667043, 0.007174287838780158], 21: [-0.6583333333333335, -0.6415977436987987], 22: [-0.26166666666666666, -0.30924356503154043], 23: [0.013333333333333475, 0.014490687269807456], 24: [-0.34833333333333355, -0.32548447380933787], 25: [-0.3066666666666666, -0.33013943154457026], 26: [0.03166666666666673, 0.016546407737217106], 27: [-0.008333333333333748, 0.003542295845189891], 28: [-0.3299999999999998, -0.33024386726120863], 29: [-1.3633333333333335, -1.3116320320594044]}
{0: [-0.3416666666666667, -0.6541205472021172], 1: [0.10499999999999965, -0.2797950051924853], 2: [0.11000000000000026, -0.34008186924602485], 3: [0.21166666666666678, -0.28721366911280455], 4: [0.013333333333332753, -0.28093438556033273], 5: [0.18833333333333302, -0.24546701810888288], 6: [0.3766666666666667, 0.036795897560373825], 7: [-0.011666666666666492, -0.38113404359677044], 8: [0.22999999999999948, -0.012897042624246124], 9: [0.46333333333333293, -0.015624334750414748], 10: [0.28333333333333266, -0.019548710057492036], 11: [0.40833333333333327, 0.03269873106725946], 12: [0.2750000000000004, 0.04142404806688926], 13: [0.11999999999999966, -0.24854397464279043], 14: [0.1016666666666663, -0.33567372625684644], 15: [0.03999999999999959, 0.03992231682976727], 16: [-0.34500000000000003, -0.31979783007204493], 17: [-0.2950000000000002, -0.32673537637451194], 18: [0.021666666666666945, 0.04418413352410394], 19: [-0.3833333333333329, -0.3392892897720922], 20: [0.006666666666667043, 0.007174287838780158], 21: [-0.6583333333333335, -0.6415977436987987], 22: [-0.25833333333333336, -0.3075776448587303], 23: [0.010000000000000175, 0.012824767096997314], 24: [-0.34833333333333355, -0.32548447380933787], 25: [-0.3066666666666666, -0.33013943154457026], 26: [0.03166666666666673, 0.016546407737217106], 27: [-0.3583333333333338, -0.32847770955543204], 28: [-0.3316666666666665, -0.3309391966460442], 29: [-1.3650000000000002, -1.31232718557848]}
{0: [-0.32000000000000006, -0.653542374142557], 1: [0.158333333333333, -0.2714030460292082], 2: [0.146666666666667, -0.33264475512244773], 3: [-0.09666666666666657, -0.608089224435045], 4: [0.01999999999999935, -0.28271181078098034], 5: [0.19499999999999973, -0.26360972799223414], 6: [0.39, 0.042357366797513896], 7: [-0.001666666666666372, -0.37713313038552293], 8: [0.23166666666666597, -0.012300172828700662], 9: [0.46333333333333293, -0.015624334750414748], 10: [0.29333333333333267, -0.020972584535922323], 11: [0.4283333333333333, 0.03879193931743552], 12: [0.2733333333333339, 0.034993481802166715], 13: [0.12166666666666626, -0.2528210286433573], 14: [0.108333333333333, -0.3388323873441658], 15: [0.044999999999999596, 0.042937610237620716], 16: [-0.34833333333333333, -0.32118885436887973], 17: [-0.2933333333333335, -0.32601545665510745], 18: [0.021666666666666945, 0.04418413352410394], 19: [-0.3833333333333329, -0.3392892897720922], 20: [0.013333333333333641, 0.007584945071034688], 21: [-0.6566666666666668, -0.6409021640875149], 22: [-0.25833333333333336, -0.3075776448587303], 23: [0.010000000000000175, 0.012824767096997314], 24: [-0.34500000000000025, -0.3234729272537752], 25: [-0.29, -0.3241094656493432], 26: [0.04000000000000015, 0.019741118407550717], 27: [-0.3600000000000005, -0.32919762927483653], 28: [-0.3333333333333332, -0.3316056585582616], 29: [-1.3650000000000002, -1.31232718557848]}
### Saving results
# BUG FIX: use a forward slash ('data/group_value.npy') instead of a
# backslash so the path also works on non-Windows systems and matches
# the forward-slash paths used everywhere else in this notebook.
np.save('data/group_value.npy', group_value)
[0.1666666666666663, -0.27690374883168134]
### Reading data value
group_value = np.load('data/group_value.npy', allow_pickle=True)
### Separating high value from low value
# Groups with a non-negative average accuracy contribution are "high
# value"; groups at or below -0.01 are "adverse" (low value).
high_value_data = []
low_value_data = []
aver_value = np.array(group_value)
for i in range(30):
    value = aver_value[i, 0]
    if value >= 0.000:
        high_value_data.append(i)
        print(value)  # high data importance value
    elif value <= -0.01:
        low_value_data.append(i)
        print(value)  # adverse data importance value
### Showing each data number
print(high_value_data)  # high value group numbers
print(low_value_data)   # adverse value group numbers
0.017661111111111096 0.0056888888888888675 0.012161111111111065 0.005483333333333319 0.013494444444444429 0.012761111111111088 0.008844444444444417 0.014466666666666636 0.0030388888888888714 0.010111111111111092 0.010955555555555527 0.00829999999999999 0.002622222222222205 0.010738888888888874 0.008299999999999983 -0.01563888888888888 -0.016244444444444454 -0.021688888888888862 -0.01418888888888889 -0.01361666666666667 -0.012866666666666672 -0.013305555555555541 -0.02439444444444442 -0.026522222222222232 -0.022800000000000018 -0.03447777777777779 -0.023666666666666697 -0.026816666666666683 -0.020183333333333338 -0.024916666666666695 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
The result shows that corrupted groups are assigned negative values, whereas useful data groups are assigned high values.
### Training with valuable datasets
ma = []  # macro-F1 learning curves, one per permutation
va = []  # accuracy learning curves, one per permutation
# Retrain 200 times, each time feeding the high-value groups in a fresh
# random order, recording performance after every added group.
# NOTE(review): accuracy() is called twice per step (once for each metric);
# results are deterministic (fixed random_state), just slower than needed.
for _ in range(200):
    data_i = r_train
    y_1 = r_labels_train
    valuable_macro = [accuracy(r_train, r_labels_train, 0)[1]]
    valuable_accuracy = [accuracy(r_train, r_labels_train, 0)[0]]
    for i in np.random.permutation(high_value_data):
        data_i = np.concatenate((data_i, x[i]), axis=0)
        y_1 = np.concatenate((y_1, y[i]), axis=0)
        valuable_macro.append(accuracy(data_i, y_1, 0)[1])
        valuable_accuracy.append(accuracy(data_i, y_1, 0)[0])
    ma.append(valuable_macro)
    va.append(valuable_accuracy)
l_ma = []  # macro-F1 learning curves for adverse data
l_va = []  # accuracy learning curves for adverse data
### Training with adverse-value data
# Same 200-permutation protocol as above, but feeding only the adverse
# (low-value) groups.
for _ in range(200):
    data_i = r_train
    y_1 = r_labels_train
    adverse_macro = [accuracy(r_train, r_labels_train, 0)[1]]
    adverse_accuracy = [accuracy(r_train, r_labels_train, 0)[0]]
    for i in np.random.permutation(low_value_data):
        data_i = np.concatenate((data_i, x[i]), axis=0)
        y_1 = np.concatenate((y_1, y[i]), axis=0)
        adverse_macro.append(accuracy(data_i, y_1, 0)[1])
        adverse_accuracy.append(accuracy(data_i, y_1, 0)[0])
    l_ma.append(adverse_macro)
    l_va.append(adverse_accuracy)
### Save
# np.save('data/ma1.npy',ma)
# np.save('data/va1.npy',va)
# np.save('data/l_ma1.npy',l_ma)
# np.save('data/l_va1.npy',l_va)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) ~\AppData\Local\Temp/ipykernel_32332/1222360119.py in <module> 8 valuable_macro =[accuracy(r_train,r_labels_train,0)[1]] 9 valuable_accuracy = [accuracy(r_train,r_labels_train,0)[0]] ---> 10 for i in np.random.permutation(valuable_data): 11 data_i = np.concatenate((data_i,x[i]), axis=0) 12 y_1 = np.concatenate((y_1,y[i]), axis=0) NameError: name 'valuable_data' is not defined
### Reading
# Load the saved learning curves and average each over the 200 permutations.
ma = np.load('data/ma.npy', allow_pickle=True)
va = np.load('data/va.npy', allow_pickle=True)
l_ma = np.load('data/l_ma.npy', allow_pickle=True)
l_va = np.load('data/l_va.npy', allow_pickle=True)
ma_va = np.average(ma, axis=0)
va_va = np.average(va, axis=0)
l_ma_va = np.average(l_ma, axis=0)
l_va_va = np.average(l_va, axis=0)
def _error_bars(curves, averages):
    """Asymmetric error-bar lists (upper, lower) for a family of curves.

    curves   -- 2-D array, one learning curve per row, 16 columns
                (baseline + 15 added groups)
    averages -- per-column mean of curves
    Index 0 of each returned list is 0: the baseline column is identical
    across runs, so it has no spread. Entries 1..15 give the distance from
    the average up to the max (upper) and down to the min (lower).
    """
    upper = [0]
    lower = [0]
    # BUG FIX (off-by-one): the original four copy-pasted loops read
    # columns 0..14, which shifted every error bar one step left and
    # dropped the spread of the final column; read columns 1..15 so each
    # error lines up with its plotted point.
    for col in range(1, 16):
        upper.append(max(curves[:, col]) - averages[col])
        lower.append(averages[col] - min(curves[:, col]))
    return upper, lower

# Error bars for each of the four averaged curves (DRY: the original had
# four identical copy-pasted loops).
err_max_ma_va, err_min_ma_va = _error_bars(ma, ma_va)
err_max_l_ma_va, err_min_l_ma_va = _error_bars(l_ma, l_ma_va)
err_max_va_va, err_min_va_va = _error_bars(va, va_va)
err_max_l_va_va, err_min_l_va_va = _error_bars(l_va, l_va_va)
np.set_printoptions(precision=2)
# Figure: learning curves for high- vs low-value data, macro-F1 on the
# left axis and accuracy on the right axis, with min/max error bars.
fig1 = plt.figure()
ax = fig1.add_axes([0, 0, 1, 1])
# 16 x positions: baseline plus 15 groups of 1000 added instances.
xdata = np.linspace(0, 15000, 16)
xdata2 = np.linspace(0, 15000, 16)
plt.rcParams["figure.figsize"] = (16, 9)
# plt.xticks(x,values)
plt.xticks(xdata, size=30, rotation=50)
# Left axis: macro-F1 curves.
yerr = [err_max_ma_va, err_min_ma_va]
ax.plot(xdata, ma_va, linewidth=4.0, label="High-value Macro-F1", color='C0', marker='o', ms=15, linestyle=':')
ax.errorbar(xdata, ma_va, yerr=yerr, fmt='o', linewidth=2, capsize=10, color='C0')
yerr1 = [err_max_l_ma_va, err_min_l_ma_va]
ax.plot(xdata2, l_ma_va, linewidth=4.0, label="Low-value Macro-F1", color='C0', marker='*', ms=25, linestyle=':')
ax.errorbar(xdata, l_ma_va, yerr=yerr1, fmt='o', linewidth=2, capsize=10, color='C0')
ax.set_ylabel('Macro-F1 score', fontsize=45, color='C0')
ax.set_xlabel('Number of added data instances', fontsize=40)
ax.legend(fontsize=30, loc=[0.01, 0.84])
# ax.set_title('Model performance',fontsize=35)
ysticks = np.arange(0.4, 0.9, 0.05)
plt.yticks(ysticks, size=30)
ax.grid(axis='x')
# Right axis: accuracy curves.
ax1 = ax.twinx()
yerr2 = [err_max_va_va, err_min_va_va]
ax1.plot(xdata, va_va, linewidth=4, label="High-value Accuracy", color='C1', marker='o', ms=15, linestyle='--')
ax1.errorbar(xdata, va_va, yerr=yerr2, fmt='o', linewidth=2, capsize=10, color='C1')
yerr3 = [err_max_l_va_va, err_min_l_va_va]
ax1.plot(xdata2, l_va_va, linewidth=4, label="Low-value Accuracy", color='C1', marker='*', ms=25, linestyle='--')
# BUG FIX: the low-value accuracy error bars previously reused yerr2 (the
# high-value errors); yerr3 was computed but never used. Plot yerr3 here.
ax1.errorbar(xdata, l_va_va, yerr=yerr3, fmt='o', linewidth=2, capsize=10, color='C1')
ax1.set_ylabel('Accuracy', fontsize=45, color='C1')
ax1.legend(fontsize=30, loc=[0.01, 0.02], handlelength=2)
ysticks = np.arange(0.70, 0.9, 0.05)
plt.yticks(ysticks, size=30)
fig1.figure.savefig('./Downloads/mode2_performance.png', dpi=450, bbox_inches='tight')